import sys
import numpy as np
import matplotlib
import argparse
import matplotlib
from matplotlib import pyplot as plt
from sklearn.calibration import calibration_curve

# Font sizes shared by the whole figure: axis labels use label_fs, tick
# labels are configured globally through rcParams.
label_fs = 16
matplotlib.rcParams.update({
    'xtick.labelsize': 14,
    'ytick.labelsize': 14,
})

# import seaborn as sns

# ----------------------------------------------------------------------------

# Command-line interface. Every option is required; the help strings describe
# the line format that the loading code further down expects for each file.
parser = argparse.ArgumentParser()

parser.add_argument(
    '--cal', required=True,
    help="calibration-curve file; one curve per line: "
         "'<label> <observed>,<predicted> <observed>,<predicted> ...'")
parser.add_argument(
    '--hist', required=True,
    help="predicted-probability file used to size the curve markers; "
         "one line per curve: '<label> <p> <p> ...'")
parser.add_argument(
    '--metrics', required=True,
    help="metrics file; one metric per line: '<label> <name>\\<value>'")
parser.add_argument(
    '--plots', required=True,
    help="time-series file; one series per line: '<name> <v1>,<v2>,...'")
parser.add_argument(
    '--out', required=True,
    help="path to figure output; a name starting with 'semparse' or "
         "'wtccc' selects that experiment's axis presets")

args = parser.parse_args()

# ----------------------------------------------------------------------------


# files = [('data/semparse-cal.txt', 
# 				'data/semparse-hist.txt',
# 				'data/semparse-metrics.txt',
# 				'data/semparse-plots.txt')]

# hist_files = ['data/multiclass-crf-calibration-bestclass-hist.data', 
# 			  'data/chain-crf-calibration-bestclass-hist.data',
# 			  'data/graph-crf-calibration-hist.data']

# metrics_files = ['data/multiclass-crf-calibration-bestclass-metrics.data', 
# 				'data/chain-crf-calibration-bestclass-metrics.data',
# 				'data/graph-crf-calibration-metrics.data']

# titles = ['Image classification (Multi-class MAP recal.);\n75% accuracy on raw uncalibrated SVM', 
# 				'OCR (Chain CRF MAP recalibration);\n45% per-word accuracy using Viterbi decoding',
# 				'Scene understanding (Graph CRF marginal recal.);\n78% accuracy using mean-field marg. decoding']

def _bucket_sizes(p, n_bins=10):
  lengths = list()
  iv_size = 1./n_bins
  for i in xrange(n_bins):
    l = len([p_j for p_j in p if i*iv_size <= p_j <= (i+1)*iv_size])
    if l:
      lengths.append(l)
  return lengths

# Abbreviated display names, presumably keyed by the method labels used in
# the data files.
# NOTE(review): this name is rebound to an index-keyed dict in the figure
# section before any lookup happens, so this mapping appears to be unused.
label_dict = {
    'uncalibrated': 'raw',
    'calibrated': 'cal',
    'one-vs-all': '1-vs-a',
}

# ----------------------------------------------------------------------------
# load data

# for i, (plot_file, hist_file, metrics_file, title) in enumerate(zip(plot_files, hist_files, metrics_files, titles)):
# for i, (cal_file, hist_file, metrics_file, plot_file) in enumerate(files):
# Each input file is read inside a `with` block so its handle is closed even
# if a line fails to parse. Blank lines are skipped in every file.

# Calibration curves: '<label> <a>,<b> <a>,<b> ...' per line. Each entry of
# cal_data is (label, [a-values, b-values]); the figure code unpacks the pair
# as (p_emp, p_true).
cal_data = list()
with open(args.cal) as cf:
    for line in cf:
        fields = line.split()
        if not fields:
            continue
        label = fields[0]
        proba = list()
        for fi in fields[1:]:
            p1, p2 = fi.split(',')
            proba.append((float(p1), float(p2)))
        # list(...) keeps the transposed data reusable on Python 3, where
        # zip() returns a one-shot iterator.
        cal_data.append((label, list(zip(*proba))))

# Probability lists: '<label> <p> <p> ...' per line.
hist_data = list()
with open(args.hist) as hf:
    for line in hf:
        fields = line.split()
        if not fields:
            continue
        label = fields[0]
        proba = [float(p) for p in fields[1:]]
        hist_data.append((label, proba))

# Metrics: '<label> <name>\<value>' per line, collected as
# metrics_data[label][name] = value.
# NOTE(review): name and value are split on a literal backslash -- confirm
# this matches what the experiment scripts write.
metrics_data = dict()
with open(args.metrics) as mf:
    for line in mf:
        fields = line.split()
        if not fields:
            continue
        label = fields[0]
        k, v = fields[1].split('\\')
        metrics_data.setdefault(label, dict())[k] = float(v)

# Time series: '<name> <v1>,<v2>,...' per line (exactly two fields).
plot_data = list()
with open(args.plots) as pf:
    for line in pf:
        if not line.strip():
            continue
        name, prob_str = line.split()
        probs = [float(p) for p in prob_str.split(',')]
        plot_data.append((name, probs))

# ----------------------------------------------------------------------------
# generate figure

matplotlib.rcParams.update({'font.size': 10})
# Outputs named 'wtccc*' or 'semparse*' share the wide layout; every other
# output gets the smaller canvas.
if args.out.startswith(('wtccc', 'semparse')):
    fig_size = (8, 3.5)
else:
    fig_size = (5, 4)
plt.figure(figsize=fig_size)

# plt.title(title)

# Panel (a), top-left: the first two series of the --plots file over time
# (original note: regret w.r.t. p_uncal).
plt.subplot2grid((2, 7), (0, 0), rowspan=1, colspan=3)
plt.title('(a) Accuracy', fontsize=15)

# Every experiment draws the same two series. The x-range is built from the
# series being drawn, not from the length of series 0, so series of unequal
# length do not cause a shape mismatch.
for idx in (0, 1):
    series = plot_data[idx][1]
    plt.plot(range(len(series)), series)

# Only the axis limits and ticks are experiment-specific.
if args.out.startswith('semparse'):
    plt.ylim([0.1, 0.25])
    plt.yticks([0.1, 0.15, 0.2, 0.25])
elif args.out.startswith('wtccc'):
    plt.ylim([0.0, 0.3])
    plt.yticks([0.0, 0.1, 0.2, 0.3])
    plt.xticks([0, 1000, 2000, 3000])
else:
    plt.ylim([-0.02, 0.05])
plt.ylabel('L2 loss', fontsize=label_fs)
plt.xlabel('Time', fontsize=label_fs)


# Panel (b), bottom-left: convergence rate of calibration.
plt.subplot2grid((2, 7), (1, 0), rowspan=1, colspan=3)
plt.title('(b) Calibration', fontsize=15)

# The rows of the --plots file holding the uncalibrated / calibrated curves
# differ per experiment. Each x-range is derived from the curve being drawn
# rather than from the length of an unrelated series.
# NOTE(review): outputs that match neither prefix leave this panel empty, as
# before; a commented-out variant also drew plot_data[5] as 'Platt scaling'
# for 'semparse'.
if args.out.startswith('semparse'):
    uncal, cal = plot_data[3][1], plot_data[4][1]
    plt.plot(range(len(uncal)), uncal, label='Uncalibrated')
    plt.plot(range(len(cal)), cal, label='Calibrated')
    plt.ylim([0, 0.06])
    plt.yticks([0, 0.015, 0.03, 0.045, 0.06])
elif args.out.startswith('wtccc'):
    uncal, cal = plot_data[2][1], plot_data[3][1]
    plt.plot(range(len(uncal)), uncal, label='Uncalibrated')
    plt.plot(range(len(cal)), cal, label='Calibrated')
    plt.ylim([0, 0.03])
    plt.yticks([0, 0.01, 0.02, 0.03])
    plt.xticks([0, 1000, 2000, 3000])
plt.ylabel('Cal. error', fontsize=label_fs)
plt.xlabel('Time', fontsize=label_fs)

# Panel (c), right half: final calibration curve for every entry of the --cal
# file, followed by layout and saving of the whole figure.
plt.subplot2grid((2, 7), (0, 3), colspan=4, rowspan=2)
plt.title('(c) Final calibration curve', fontsize=15)

# Legend text and colour are chosen by a curve's position in the --cal file,
# so at most three curves are supported.
label_dict = {0: 'Uncalibrated', 1: 'Recalibrated', 2: 'Platt scaling'}
colors = ('blue', 'green', 'red')
for j, (label, y_p_pair) in enumerate(cal_data):
    p_emp, p_true = y_p_pair
    # Marker area is proportional to the share of values from row j of the
    # --hist file that fall into each non-empty probability bin.
    # NOTE(review): this assumes the number of non-empty bins equals the
    # number of points on the curve -- confirm against the data producer.
    sizes = np.array(_bucket_sizes(hist_data[j][1]))
    S = np.sum(sizes)
    plt.scatter(p_true, p_emp, s=200.*sizes/S, c=colors[j], edgecolors=colors[j])
    full_label = label_dict[j]
    # Pin the line colour to the marker colour. Left to the default colour
    # cycle, the two only agree on matplotlib releases older than 2.0.
    plt.plot(p_true, p_emp, '-', color=colors[j], label=full_label)

# Diagonal reference: a perfectly calibrated model lies on y = x.
plt.plot(np.linspace(0, 1, 20), np.linspace(0, 1, 20), color='gray')

plt.ylim([-0.05, 1.05])

plt.legend(loc="lower right")

plt.xlabel("Predicted probability", fontsize=label_fs)
plt.ylabel('Observed probability', fontsize=label_fs)
plt.xlim((0, 1))

plt.tight_layout()
plt.savefig(args.out, bbox_inches='tight')
